import plotly.graph_objects as go
import pandas as pd
# Assumes df_sankey_models_transformed already exists with the columns
# 'gm', 'mit', 'oa', 'gm_count' and 'oa_count'.

# 1. Collect the distinct node labels across the three stage columns.
labels = list(
    pd.concat(
        [df_sankey_models_transformed[column] for column in ('gm', 'mit', 'oa')]
    ).unique()
)

# 2. Give every label the positional index used to reference its node.
label_map = dict(zip(labels, range(len(labels))))

# 3. Each row contributes two links: gm -> mit and mit -> oa.  Translate each
# stage column to node indices once, then stitch the two halves together.
gm_nodes = df_sankey_models_transformed['gm'].map(label_map).tolist()
mit_nodes = df_sankey_models_transformed['mit'].map(label_map).tolist()
oa_nodes = df_sankey_models_transformed['oa'].map(label_map).tolist()

sources = gm_nodes + mit_nodes
targets = mit_nodes + oa_nodes

# gm -> mit links are weighted by 'gm_count', mit -> oa links by 'oa_count'.
values = [
    *df_sankey_models_transformed['gm_count'].tolist(),
    *df_sankey_models_transformed['oa_count'].tolist(),
]
# 4. Generate pastel colors
import random
def pastel_color() -> str:
    """Return a random, semi-transparent pastel colour as an ``rgba()`` string.

    The red, green and blue channels are each drawn independently with
    ``random.randint(100, 255)``; keeping every channel at 100 or above
    rules out dark colours.  The alpha component is fixed at 0.6.

    Returns:
        A string of the form ``'rgba(R,G,B,0.6)'``.
    """
    # Channels are drawn in red, green, blue order, one randint call each.
    red, green, blue = (random.randint(100, 255) for _ in range(3))
    return f'rgba({red},{green},{blue},0.6)'
# Give every node label its own random pastel colour (one pastel_color()
# call per label, in label order), then list the colours in label order so
# they line up with the node labels by position.
label_to_color = {label: pastel_color() for label in labels}
node_color_list = [label_to_color[label] for label in labels]

# 5. Colour each link after the node it leaves: the first half of the links
# takes the colour of the row's 'gm' label, the second half the colour of the
# row's 'mit' label.  The columns are iterated directly, in row order, rather
# than through positional range(len(...)) / .iloc[i] lookups.
link_colors = [
    label_to_color[label] for label in df_sankey_models_transformed['gm']
] + [
    label_to_color[label] for label in df_sankey_models_transformed['mit']
]
# 6. Create the figure
# Hover text comes from the 'nameyear' column.  The per-row list is repeated
# once so that it is as long as the two link halves put together.
row_names = df_sankey_models_transformed['nameyear'].tolist()

node_settings = {
    'pad': 15,
    'thickness': 20,
    'line': {'color': "black", 'width': 0.9},
    'label': labels,
    'color': node_color_list,
    'align': "left",
}

link_settings = {
    'source': sources,
    'target': targets,
    'value': values,
    'color': link_colors,
    'customdata': row_names + row_names,
    # Show only the custom text; the empty <extra> tag hides the secondary box.
    'hovertemplate': '%{customdata}<extra></extra>',
}

sankey_trace = go.Sankey(node=node_settings, link=link_settings)
fig = go.Figure(data=[sankey_trace])

fig.update_layout(
    title_text="Sankey Diagram for model types",
    font_size=16.5,
    width=1500,
    height=800,
    hovermode='x',
)

# Display the figure, then export it to an HTML file.
fig.show()
fig.write_html(
    "C:/Users/U727148/Latent_Variable_Supplement/sankey_models_plot.html",
    include_plotlyjs="cdn",
)